# Example comparing the optimality of the three functions MH, MH_high_dim and
# MH_high_dim_para on the nonlinear mixed-effects model.
# NOTE(review): the runs below actually compare SAEM fits with different
# parameters held fixed (`exclude.maximisation`); this header may be out of
# date -- confirm.

# Logistic curve used as the nonlinear longitudinal model.
#   t     observation time(s); the expression is vectorised over all arguments
#   phi1  asymptote (the curve tends to phi1 as t grows, for phi3 > 0)
#   phi2  time at which the curve reaches phi1 / 2
#   phi3  horizontal scale of the transition
m <- function(t, phi1, phi2, phi3) {
  denom <- 1 + exp((phi2 - t) / phi3)
  phi1 / denom
}
#=======================================#
# True parameter values used to simulate the data.
p <- 100  # number of regression coefficients (length of beta)
parameter <- list(
  sigma2   = 0.05^2,           # residual variance of the observations
  # rho2 = 5,
  mu       = c(0.9, 90, 5),    # prior means of (phi1, phi2, phi3)
  omega2   = c(0.005, 40, 1),  # prior variances of (phi1, phi2, phi3)
  # S.data data,
  bara     = 90,
  barb     = 30,               # prior mean of the latent variable b
  baralpha = 0.5,              # prior mean of the latent variable alpha
  # sparse truth: only the first four coefficients are non-zero
  beta     = c(-0.8, -0.2, 0.3, 0.9, rep(0, p - 4))
)
#=======================================#
# Observation times: 10 equally spaced values from 60 to 120.
t <- seq(60, 120, length.out = 10)
# The seed must be set before the simulation below for reproducibility.
set.seed(123)

# Presumably G groups of ng individuals each (see `gen` / `id` in the plot).
G <- 40
ng <- 4
# Alternative link kept for reference:
# link = function(t,phi1, phi2, phi3) phi2#/phi3
link <- m

# Simulate the joint longitudinal/survival data set (project function).
dt <- create_JLS_HD_data(G, ng, t, m, link, parameter)

# Unpack the simulated data set.
S.data <- dt$survival  # survival data; its `obs` column is used below
U <- dt$U
var.true <- dt$var.true  # true values of the latent variables
# `a` is held fixed: keep its true value aside and remove it from the
# latent variables.
a <- var.true$a
var.true$a <- NULL

# Longitudinal observations: get_obs() evaluated at the true latent variables
# plus Gaussian noise with standard deviation sqrt(parameter$sigma2).
# NOTE(review): `n` is not defined in this script; presumably the total number
# of observations (G * ng * length(t)), created elsewhere -- confirm.
Y <- do.call(get_obs, var.true) +
  rnorm(n, mean = 0, sd = sqrt(parameter$sigma2))

# Observed survival times and the sum of their logs.
S.data.time <- S.data$obs
S.data.time.log.sum <- S.data.time %>% log() %>% sum()
# Diagnostic plots of the simulated data: longitudinal trajectories (one line
# per individual, coloured by group) above a histogram of the survival times.
# NOTE(review): `N` and `time` are not defined in this script. `N` is presumably
# G * ng, so that `id` and `gen` line up, and `time` the observation time of each
# element of Y. Confirm where they are created; if `time` resolves to
# stats::time, data.frame() will fail.
longitudinal_plot <- data.frame(time, Y,
                                id = rep(seq_len(N), each = length(t)),
                                gen = rep(seq_len(G), each = ng * length(t))) %>%
  ggplot(aes(time, Y, col = factor(gen), group = factor(id))) +
  geom_point() +
  geom_line() +
  # "none" is the documented value for hiding the legend; "null" only hid it
  # because it matched no known position, and is not guaranteed to keep working.
  theme(legend.position = "none")

S.data_plot <- S.data %>%
  ggplot(aes(obs)) +
  geom_histogram(col = "white", position = "identity", bins = 30) +
  theme(legend.position = "none")

grid.arrange(longitudinal_plot, S.data_plot, nrow = 2)

# SAEM model specification (SAEM_model and the helpers used below are defined
# outside this script).
# NOTE(review): the first three positional arguments appear to describe the
# residual-variance update: a function of sigma2 (-n / (2 * sigma2)), the mean
# squared residual between Y and get_obs(phi1, phi2, phi3), and the name of the
# parameter concerned ('sigma2') -- confirm against SAEM_model's documentation.
model <- SAEM_model( 
  function(sigma2, ...) -n/(2*sigma2),
  function(phi1, phi2, phi3, ...) mean((Y - get_obs(phi1, phi2, phi3) )^2 ), 'sigma2',
  
  # === Latent variables === #
  # `add_on` entries are R code strings; they are presumably added by the
  # package to each variable's target density -- confirm.
  latent_vars = list(
    # === Nonlinear model === #
    # phi: dimension G, 3 components (phi1, phi2, phi3); prior mean `mu` and
    # prior variance `omega2` (both entries of `parameter`).
    latent_variable('phi', dim = G, size = 3, prior = list(mean = 'mu', variance = 'omega2'),
                    add_on = c('zeta(phi1 = phi1, phi2 = phi2, phi3 = phi3, ...)' )),
    
    # === Survival (S.data) model === #
    # b and alpha have prior means `barb` and `baralpha`.
    # NOTE(review): 'sigma2_b' and 'sigma2_alpha' are not entries of `parameter`;
    # presumably hyper-variances handled by the package -- confirm.
    latent_variable('b', prior = list(mean = 'barb', variance.hyper = 'sigma2_b'),
                    add_on = c('zeta(b = b, ...) +',
                               'sum(h$eval(b = b, ..., i = c(1,2)))' )),
    latent_variable('alpha', prior = list(mean = 'baralpha', variance.hyper = 'sigma2_alpha'),
                    add_on = c('zeta(alpha = alpha, ...) +',
                               'alpha*h$eval(alpha = alpha,..., i = 3)'))
  ),

  # === Regression parameter === #
  # beta is updated by SPGD starting from the current beta, with step 0.05 and
  # lambda = 1/sqrt(N); presumably a (penalised) stochastic proximal gradient
  # step -- confirm. `zeta.der.B`, `zeta.B`, `Z` and `N` are resolved when the
  # function is called, not here.
  # NOTE(review): `normalized.grad = T` relies on `T` not being reassigned;
  # TRUE would be safer.
  regression.parameter = list(
    regression_parameter('beta', 1, function(...) SPGD(1, theta0 = beta,
                                                      step = 0.05, lambda = 1/sqrt(N),
                                                      normalized.grad = T,
                                                      zeta.der.B, N, zeta.B, 
                                                      Z$alpha,  Z$phi1, Z$phi2, Z$phi3,Z$b) )
  )
)
# --- Parameter initialisation --- #
# Start every parameter from its true value multiplied by a single random
# factor drawn uniformly in [1.1, 1.4] (one draw per list element, in order).
# lapply always returns a list, whereas sapply's result type depends on the
# element lengths (it would simplify to an atomic vector if every parameter
# were a scalar). Here the result, including RNG use, is the same as before.
parameter0 <- lapply(parameter, function(x) x * runif(1, 1.1, 1.4))
# beta is instead started from independent uniform draws in [-1, 1].
parameter0$beta <- runif(p, min = -1, max = 1)

#===============================================#
# Prepare the model (load.SAEM is defined outside this script).
# NOTE(review): `S` (and presumably `h`, `zeta`, `Z` used in `model` and in the
# plots) are not created in this script; they are presumably brought into scope
# by load.SAEM(model) -- confirm.
load.SAEM(model)
# Statistics of the model (presumably the SAEM sufficient statistics) evaluated
# at the true latent variables.
S.tmp <- do.call(S$eval, var.true)
# Oracle estimate: maximisation() applied to those statistics together with the
# true parameters and latent variables; used as `true.value` in the plots below.
# NOTE(review): this evaluates do.call(S$eval, var.true) a second time; if
# S$eval is deterministic and side-effect free, `S.tmp` could be passed instead.
oracle <- maximisation(1, do.call(S$eval, var.true), parameter, var.true)
#==============================================================================#

# Initial state of the latent variables (x0) and the initial proposal standard
# deviations (sd) for each of them.
init.options <- list(
  x0 = list(
    phi   = c(1, 80, 4),
    b     = parameter0$barb,
    alpha = parameter0$baralpha
  ),
  sd = list(
    phi   = c(0.05, 1.5, 0.5),
    b     = 1,
    alpha = 0.1
  )
)

# SAEM run settings: 200 iterations, the first 190 being burn-in.
# `sim.iter` is presumably the number of MCMC steps per iteration -- confirm.
SAEM.options <- list(
  niter        = 200,
  sim.iter     = 5,
  burnin       = 190,
  # key spelled as in the original (presumably the name run() looks up)
  adptative.sd = 0.6
)

# ---- saem ----

# Run 1: no parameter excluded from the maximisation step.
# `params$rds_filename` presumably comes from the R Markdown `params` header.
res <- run(model, parameter0, init.options, SAEM.options, verbatim = 3)
saveRDS(res, file = paste0(params$rds_filename, "_", p, ".rds"))

# = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = #

# Diagnostics of the fit `res` against the `oracle` values, beta excluded
# (returns the plots listed in the output below).
plot(res, true.value = oracle, exclude = 'beta')
## [1] "SAEM execution time = 00h 32min 37sec"
## $plot_parameter

## 
## $plot_MCMC

## 
## $plot_acceptation

# Plots for the high-dimensional regression parameter beta.
plot_high_dim(res, oracle, 'beta', zeta, dec = 0, 
              var.true$alpha, var.true$phi1, var.true$phi2, var.true$phi3, var.true$b)
## [[1]]

## 
## [[2]]

## 
## [[3]]

# Summary table of real vs estimated values (shown below).
# FALSE rather than F: `F` is an ordinary variable that can be reassigned.
plot(res, true.value = oracle, var = 'summary', exclude = 'beta', time = FALSE)
## Result of the SAEM-MCMC
## sigma2 mu.1 mu.2 mu.3 omega2.1 omega2.2 omega2.3 barb baralpha
## Real value 0.0026 0.9032 89.9575 5.0079 0.0039 35.9179 0.6948 30.0000 0.5000
## Estimated value 0.0026 0.9003 89.9695 4.9188 0.0045 35.1034 0.6817 37.0008 2.2036
## Rrmse 0.0036 0.0032 0.0001 0.0178 0.1333 0.0227 0.0188 0.2334 3.4072

# ---- saem ----

# Run 2: `baralpha` is excluded from the maximisation step (held fixed).
load.options <- list(exclude.maximisation = c('baralpha') )
# Start the excluded parameters from their true values, since they are
# presumably never updated by run().
parameter0[load.options$exclude.maximisation] <- parameter[load.options$exclude.maximisation]

res <- run(model, parameter0, init.options, SAEM.options, load.options, verbatim = 3)
# Tag the file name with the fixed parameters. The plain "<rds_filename>_<p>.rds"
# path is also written by the other runs in this script, so reusing it would
# overwrite their results.
saveRDS(res, paste0(params$rds_filename, '_', p, '_fixed-',
                    paste(load.options$exclude.maximisation, collapse = '-'),
                    '.rds'))

# = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = #

# Diagnostics of the fit `res` (baralpha fixed) against the `oracle` values,
# beta excluded (returns the plots listed in the output below).
plot(res, true.value = oracle, exclude = 'beta')
## [1] "SAEM execution time = 00h 31min 39sec"
## $plot_parameter

## 
## $plot_MCMC

## 
## $plot_acceptation

# Plots for the high-dimensional regression parameter beta.
plot_high_dim(res, oracle, 'beta', zeta, dec = 0, 
              var.true$alpha, var.true$phi1, var.true$phi2, var.true$phi3, var.true$b)
## [[1]]

## 
## [[2]]

## 
## [[3]]

# Summary table of real vs estimated values (shown below).
# FALSE rather than F: `F` is an ordinary variable that can be reassigned.
plot(res, true.value = oracle, var = 'summary', exclude = 'beta', time = FALSE)
## Result of the SAEM-MCMC
## sigma2 mu.1 mu.2 mu.3 omega2.1 omega2.2 omega2.3 barb baralpha
## Real value 0.0026 0.9032 89.9575 5.0079 0.0039 35.9179 0.6948 30.0000 0.5
## Estimated value 0.0026 0.9016 89.9684 5.0528 0.0040 35.6588 0.8133 38.4418 0.5
## Rrmse 0.0041 0.0018 0.0001 0.0090 0.0089 0.0072 0.1707 0.2814 0.0

# ---- saem ----

# Run 3: both `baralpha` and `barb` are excluded from the maximisation step
# (held fixed).
load.options <- list(exclude.maximisation = c('baralpha', 'barb') )
# Start the excluded parameters from their true values, since they are
# presumably never updated by run().
parameter0[load.options$exclude.maximisation] <- parameter[load.options$exclude.maximisation]

res <- run(model, parameter0, init.options, SAEM.options, load.options, verbatim = 3)
# Tag the file name with the fixed parameters. The plain "<rds_filename>_<p>.rds"
# path is also written by the other runs in this script, so reusing it would
# overwrite their results.
saveRDS(res, paste0(params$rds_filename, '_', p, '_fixed-',
                    paste(load.options$exclude.maximisation, collapse = '-'),
                    '.rds'))

# = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = # = = = = = = = = = #

# Diagnostics of the fit `res` (baralpha and barb fixed) against the `oracle`
# values, beta excluded (returns the plots listed in the output below).
plot(res, true.value = oracle, exclude = 'beta')
## [1] "SAEM execution time = 00h 28min 49sec"
## $plot_parameter

## 
## $plot_MCMC

## 
## $plot_acceptation

# Plots for the high-dimensional regression parameter beta.
plot_high_dim(res, oracle, 'beta', zeta, dec = 0, 
              var.true$alpha, var.true$phi1, var.true$phi2, var.true$phi3, var.true$b)
## [[1]]

## 
## [[2]]

## 
## [[3]]

# Summary table of real vs estimated values (shown below).
# FALSE rather than F: `F` is an ordinary variable that can be reassigned.
plot(res, true.value = oracle, var = 'summary', exclude = 'beta', time = FALSE)
## Result of the SAEM-MCMC
## sigma2 mu.1 mu.2 mu.3 omega2.1 omega2.2 omega2.3 baralpha barb
## Real value 0.0026 0.9032 89.9575 5.0079 0.0039 35.9179 0.6948 0.5 30
## Estimated value 0.0025 0.9021 89.9561 4.9671 0.0044 35.4356 0.7924 0.5 30
## Rrmse 0.0056 0.0012 0.0000 0.0081 0.1145 0.0134 0.1406 0.0 0